# ------------------------------------------------------------------------------------------------
### Custom synthetic difference-in-differences plotting functions
# ------------------------------------------------------------------------------------------------

# Contract the third axis of a 3-d array against a weight vector.
#
# Returns the matrix sum_k v[k] * X[, , k], whose dimensions are the first two
# dimensions of X. If v is empty (so X has a zero-length third axis), the
# result is an all-zero matrix of those dimensions.
#
# Arguments:
#   X  a 3-d array; dim(X)[3] must equal length(v) (checked with stopifnot).
#   v  numeric weights, one per slice of X along its third axis.
contract3 = function(X, v) {
  stopifnot(length(dim(X)) == 3, dim(X)[3] == length(v))
  zero.matrix = array(0, dim = dim(X)[1:2])
  # Left fold starting from zero: accumulates the slices in order 1, 2, ...,
  # and returns zero.matrix untouched when v has no elements.
  Reduce(function(acc, k) acc + v[k] * X[, , k], seq_along(v), zero.matrix)
}

# Plot synthdid estimates: treated vs. synthetic-control trajectories.
#
# For every combination of estimate and overlay value this computes:
#   - the treated trajectory;
#   - the synthetic-control trajectory, shifted by `overlay` times the
#     weighted pre-period gap;
#   - the DiD diagram geometry;
#   - optional per-unit "spaghetti" lines.
# It then stacks those pieces across estimates and draws a subset of them with
# ggplot2.
#
# Layers actually drawn:
#   - trajectory points and lines;
#   - spaghetti lines and labels (if any);
#   - facets (only when estimates fall into more than one facet);
#   - date-formatted x ticks when the column names of the first estimate's Y
#     parse as dates.
# The DiD diagram layers are kept as commented-out code. The arguments that
# only feed them (point.size, trajectory.alpha, diagram.alpha, effect.alpha,
# onset.alpha, ci.alpha, effect.curvature, guide.linetype) therefore have no
# visible effect at present.
#
# Arguments:
#   estimates           a 'synthdid_estimate' or a list of them; an unnamed
#                       list is named 'estimate 1', 'estimate 2', ...
#   treated.name, control.name
#                       legend labels for the two groups.
#   spaghetti.units     row names of Y whose own trajectories are drawn.
#   spaghetti.matrices  optional list (same length as estimates) of matrices
#                       with named rows and ncol(Y) columns, drawn as extra
#                       spaghetti.
#   facet               optional vector, one value per estimate; estimates with
#                       equal values share a panel labelled by that value.
#   facet.vertical      stack facets vertically (TRUE) or side by side (FALSE).
#   overlay             multiplier(s) on the intercept offset of the synthetic
#                       trajectory; more than one value adds a `frame`
#                       aesthetic (the comments below note that ggplotly uses
#                       it for animation).
#   se.method           passed to vcov(); 'none' skips the standard error.
#   alpha.multiplier    per-estimate value stored in the `show` column (used
#                       in the spaghetti alpha); defaults to 1 for every
#                       estimate.
# Returns: a ggplot object.
synthdid_plot_custom = function(estimates, treated.name = 'treated', control.name = 'synthetic control', 
                         spaghetti.units = c(), spaghetti.matrices = NULL,
                         facet = NULL, facet.vertical = TRUE, lambda.comparable = !is.null(facet), overlay = 0,
                         lambda.plot.scale = 3, trajectory.linetype = 1, effect.curvature = .3, line.width = 1.5, guide.linetype = 2, point.size = 2.5,
                         trajectory.alpha = 1, diagram.alpha = 1, effect.alpha = 1, onset.alpha = 1, ci.alpha=1,
                         spaghetti.line.width = .2, spaghetti.label.size = 2, 
                         spaghetti.line.alpha = 1, spaghetti.label.alpha = 1,
                         se.method='jackknife', alpha.multiplier = NULL) {
  if (requireNamespace("ggplot2", quietly = TRUE)) {
    # The layer constructors below are called unqualified, so ggplot2 must be
    # attached.
    .ignore <- tryCatch(attachNamespace("ggplot2"), error = function(e) e)
  } else {
    stop("Plotting requires the package `ggplot2`. Install it to use this function.")
  }
  # inherits() also works for objects carrying several classes. With
  # class(x) == '...', such an object would give if() a length > 1 condition,
  # which is an error.
  if (inherits(estimates, 'synthdid_estimate')) { estimates = list(estimates) }
  if (is.null(names(estimates))) { names(estimates) = sprintf('estimate %d', 1:length(estimates)) }
  if (is.null(alpha.multiplier)) { alpha.multiplier = rep(1, length(estimates)) }
  if (!is.null(spaghetti.matrices) && length(spaghetti.matrices) != length(estimates)) { stop('spaghetti.matrices must be the same length as estimates') }
  multiple.frames = length(overlay) > 1
  treated = 1
  control = 2
  groups = factor(c(control, treated), labels = c(control.name, treated.name))
  # Level 1 is a pseudo-estimate for the treated series; estimate i is level i + 1.
  estimate.factors = factor(1:(length(estimates) + 1), labels = c(treated.name, names(estimates)))
  facet_factors = if (is.null(facet)) {
    factor(1:length(estimates), labels = names(estimates))
  } else {
    # Levels appear in order of first appearance, and each is labelled by its
    # own value. Forcing the levels to 1..k instead would mislabel facets not
    # numbered 1..k in that order, and would turn character facet values into
    # NA.
    factor(facet, levels = unique(facet))
  }
  grid = expand.grid(estimate = 1:length(estimates), overlay = 1:length(overlay))
  plot.descriptions = lapply(1:nrow(grid), function(row) {
    est = estimates[[grid$estimate[row]]]
    over = overlay[grid$overlay[row]]
    # The standard error feeds the CI arrows; it is NA when se.method == 'none'.
    se = if (se.method == 'none') { NA } else { sqrt(vcov(est, method = se.method)) }
    setup = attr(est, 'setup')
    weights = attr(est, 'weights')
    # Outcomes net of the covariate term (X contracted with the fitted beta).
    Y = setup$Y - contract3(setup$X, weights$beta)
    N0 = setup$N0; N1 = nrow(Y) - N0
    T0 = setup$T0; T1 = ncol(Y) - T0

    # Time weights:
    #   synthetic = lambda over the first T0 columns;
    #   target    = uniform over the last T1 columns.
    lambda.synth = c(weights$lambda, rep(0, T1))
    lambda.target = c(rep(0, T0), rep(1 / T1, T1))
    # Unit weights:
    #   synthetic = omega over the first N0 rows;
    #   target    = uniform over the last N1 rows.
    omega.synth = c(weights$omega, rep(0, N1))
    omega.target = c(rep(0, N0), rep(1 / N1, N1))

    # Pull an estimate-specific overlay from the attribute, if present.
    # A synthetic control estimate (all lambda zero), or an overlay of 1, is
    # flagged so it can be plotted differently.
    if (!is.null(attr(est, 'overlay'))) { over = attr(est, 'overlay') }
    is.sc = all(weights$lambda == 0) || over == 1

    # Shift the synthetic trajectory by `over` times its lambda-weighted gap to
    # the target trajectory.
    intercept.offset = over * c((omega.target - omega.synth) %*% Y %*% lambda.synth)
    obs.trajectory = as.numeric(omega.target %*% Y)
    syn.trajectory = as.numeric(omega.synth %*% Y) + intercept.offset
    spaghetti.trajectories = Y[rownames(Y) %in% spaghetti.units, , drop = FALSE]
    if (!is.null(spaghetti.matrices)) {
      more.spaghetti.trajectories = spaghetti.matrices[[grid$estimate[row]]]
      if (ncol(more.spaghetti.trajectories) != ncol(Y)) { stop('The elements of spaghetti.matrices must be matrices with the same number of columns as Y') }
      if (is.null(rownames(more.spaghetti.trajectories))) { stop('The elements of the list spaghetti.matrices must have named rows') }
      spaghetti.trajectories = rbind(spaghetti.trajectories, more.spaghetti.trajectories)
    }

    treated.post = omega.target %*% Y %*% lambda.target
    treated.pre = omega.target %*% Y %*% lambda.synth
    control.post = omega.synth %*% Y %*% lambda.target + intercept.offset
    control.pre = omega.synth %*% Y %*% lambda.synth + intercept.offset
    # Counterfactual post-period level: the control change added to the
    # treated pre level.
    sdid.post = as.numeric(control.post + treated.pre - control.pre)

    # Numeric x axis. Falls back to 1..(T0 + T1) when timesteps(Y) does not
    # yield finite numbers.
    time = as.numeric(timesteps(Y))
    if (length(time) == 0 || !all(is.finite(time))) { time = 1:(T0 + T1) }
    pre.time = lambda.synth %*% time
    post.time = lambda.target %*% time

    # construct objects on graph
    lines = data.frame(x = rep(time, 2),
                       y = c(obs.trajectory, syn.trajectory),
                       color = rep(groups[c(treated, control)], each = length(time)))
    points = data.frame(x = c(post.time, post.time), y = c(treated.post, sdid.post), color = groups[c(treated, control)])
    did.points = data.frame(x = c(pre.time, pre.time, post.time, post.time),
                            y = c(treated.pre, control.pre, control.post, treated.post),
                            color = groups[c(treated, control, control, treated)])
    did.segments = data.frame(x = c(pre.time, pre.time),
                              xend = c(post.time, post.time),
                              y = c(control.pre, treated.pre),
                              yend = c(control.post, treated.post),
                              color = groups[c(control, treated)])
    hallucinated.segments = data.frame(x = pre.time, xend = post.time, y = treated.pre, yend = sdid.post)
    guide.segments = data.frame(x = c(pre.time, post.time),
                                xend = c(pre.time, post.time),
                                y = c(control.pre, control.post),
                                yend = c(treated.pre, sdid.post))
    arrows = data.frame(x = post.time, xend = post.time, y = sdid.post, yend = treated.post, 
                        xscale = max(time) - post.time, color = groups[control])
    # Effect arrows starting at sdid.post +/- 1.96 * se. The start is NA when
    # se is NA.
    ub.arrows = data.frame(x = post.time, xend = post.time, y = sdid.post + 1.96*se, yend = treated.post, 
                           xscale = max(time) - post.time, color = groups[control])
    lb.arrows = data.frame(x = post.time, xend = post.time, y = sdid.post - 1.96*se, yend = treated.post, 
                           xscale = max(time) - post.time, color = groups[control])
    spaghetti.lines = data.frame(x = rep(time, nrow(spaghetti.trajectories)),
                                 y = as.vector(t(spaghetti.trajectories)),
                                 unit = rep(rownames(spaghetti.trajectories), each = length(time)))
    spaghetti.labels = data.frame(x = rep(time[1], nrow(spaghetti.trajectories)),
                                  y = as.vector(spaghetti.trajectories[, 1]), 
                                  unit = rownames(spaghetti.trajectories))

    # Onset line(s): at time[T0s] when the estimate carries a 'T0s' attribute,
    # otherwise at time[T0].
    T0s = attr(est, 'T0s')
    if (!is.null(T0s)) {
      vlines = data.frame(xintercept = time[T0s])
    } else {
      vlines = data.frame(xintercept = time[T0])
    }

    # Lambda-weight ribbons below the trajectories. They are computed but not
    # drawn by any active layer.
    if (lambda.comparable) {
      height = (max(c(obs.trajectory)) - min(c(obs.trajectory))) / lambda.plot.scale
      bottom = min(c(obs.trajectory)) - height
      ribbons = data.frame(x = time[1:T0], ymin = rep(bottom, T0), ymax = bottom + height * lambda.synth[1:T0], color = groups[control])
    } else {
      height = (max(c(obs.trajectory, syn.trajectory)) - min(c(obs.trajectory, syn.trajectory))) / lambda.plot.scale
      bottom = min(c(obs.trajectory, syn.trajectory)) - height
      ribbons = data.frame(x = time[1:T0], ymin = rep(bottom, T0), ymax = bottom + height * lambda.synth[1:T0] / max(lambda.synth), color = groups[control])
    }
    elements = list(lines = lines, points = points, did.segments = did.segments, did.points = did.points,
                    hallucinated.segments = hallucinated.segments, guide.segments = guide.segments,
                    arrows = arrows, lb.arrows = lb.arrows, ub.arrows = ub.arrows, spaghetti.lines = spaghetti.lines, spaghetti.labels = spaghetti.labels,
                    vlines = vlines, ribbons = ribbons)
    # Tag every non-empty element with its frame, SC flag and estimate factor.
    lapply(elements, function(x) {
      if (nrow(x) > 0) { 
        x$frame = over
        x$is.sc = is.sc
        x$estimate = estimate.factors[grid$estimate[row] + 1] # offset because the treated pseudo-estimate factor is first
      }
      x
    })
  })

  one.per.facet = length(unique(facet_factors)) == length(facet_factors)
  # Stack one element type (e.g. 'lines') across all plot descriptions and add
  # the facet and `show` columns.
  concatenate.field = function(field) {
    do.call(rbind, lapply(plot.descriptions, function(desc) {
      element = desc[[field]]
      estimate.factor = element$estimate[1]
      element$facet = facet_factors[as.integer(estimate.factor) - 1] # offset because the treated pseudo-estimate factor is first
      element$show = alpha.multiplier[as.integer(element$estimate) - 1] # "
      element$show[element$color == groups[treated]] = 1 # show treated observations
      # if there are multiple plots per facet, color by estimator rather than by treatment/control
      # make all treated observations the same color, assuming that we're using only one treated observation per facet
      if (!one.per.facet && 'color' %in% colnames(element)) {
        color = element$estimate
        color[element$color == groups[treated]] = estimate.factors[1] # treated `estimate factor'
        element$color = color
      }
      element
    }))
  }
  conc = lapply(names(plot.descriptions[[1]]), concatenate.field)
  names(conc) = names(plot.descriptions[[1]])
  # Drop synthetic-control rows (referenced only by the commented-out layers).
  no.sc = function(x) { x[!x$is.sc, ] }

  # Invoke the geom with 'frame=frame' included in the aesthetic only if there
  # are multiple frames. This cuts down on warnings without restricting
  # function:
  #   ggplotly understands frame and uses it to create animations if there are multiple;
  #   ggplot2's display doesn't and warns 'ignoring unknown aesthetic' if you include it.
  # Returns geom(aes, data=data, ...) where
  #   aes = base.aes           if there's a single frame;
  #   aes = (base.aes, frame)  if there's more than one.
  with.frame = function(geom, base.aes, data, ...) {
    new.aes = if (multiple.frames) { modifyList(base.aes, aes(frame = frame)) } else { base.aes }
    do.call(geom, c(list(new.aes, data = data), list(...)))
  }

  # Base plot: trajectory points and lines. The expression must NOT end with
  # a trailing `+`. Otherwise R keeps parsing past the comments below and
  # treats the spaghetti `if` as the right-hand operand of this sum; that `if`
  # then reads `p` before it is assigned.
  # (alpha = trajectory.alpha * show is intentionally not mapped on the lines.)
  p = ggplot() +
    with.frame(geom_point, aes(x = x, y = y, color = color, shape = color), data = conc$lines, size = 3) +
    with.frame(geom_line, aes(x = x, y = y, color = color), data = conc$lines,
               linetype = trajectory.linetype, size = line.width)

  # Disabled diagram layers. To re-enable one, add it with `p = p + <layer>`.
  #   with.frame(geom_pointrange, aes(x = x, y = y, ymin = ymin, ymax = ymax, color = color), data = conc$lines) +
  #   with.frame(geom_point, aes(x = x, y = y, color = color, alpha = diagram.alpha * show), data = conc$points, 
  #              shape = 21, size = point.size) +
  #   with.frame(geom_point, aes(x = x, y = y, color = color, alpha = diagram.alpha * show), data = no.sc(conc$did.points), 
  #              size = point.size) +
  #   with.frame(geom_segment, aes(x = x, xend = xend, y = y, yend = yend, color = color, alpha = diagram.alpha * show), data = no.sc(conc$did.segments), 
  #              size = line.width) +
  #   with.frame(geom_segment, aes(x = x, xend = xend, y = y, yend = yend, group = estimate, alpha = .6 * diagram.alpha * show), data = no.sc(conc$hallucinated.segments), 
  #              linetype = guide.linetype, size = line.width, color = 'black') +
  #   with.frame(geom_segment, aes(x = x, xend = xend, y = y, yend = yend, group = estimate, alpha = .5 * diagram.alpha * show), data = no.sc(conc$guide.segments), 
  #              size = line.width, linetype = guide.linetype, color = 'black') +
  #   geom_vline(aes(xintercept = xintercept, alpha = onset.alpha * show), data = conc$vlines, 
  #              size = line.width, color = 'black') +
  #   geom_curve(aes(x = x, xend = xend, y = y, yend = yend, alpha = effect.alpha * show), data = conc$arrows, 
  #              curvature = effect.curvature, color = 'black', size = line.width, arrow = arrow(length = unit(.2, 'cm'))) + 
  #   geom_curve(aes(x = x, xend = xend, y = y, yend = yend, alpha = ci.alpha * show), data = conc$ub.arrows, na.rm=TRUE,
  #              curvature = effect.curvature, color = 'black', size = line.width, arrow = arrow(length = unit(.2, 'cm'))) + 
  #   geom_curve(aes(x = x, xend = xend, y = y, yend = yend, alpha = ci.alpha * show), data = conc$lb.arrows, na.rm=TRUE, 
  #              curvature = effect.curvature, color = 'black', size = line.width, arrow = arrow(length = unit(.2, 'cm')))

  # plot spaghetti if there is any
  if (nrow(conc$spaghetti.labels) > 0) {
    p = p + geom_text(aes(x = x, y = y, label = unit, alpha = spaghetti.label.alpha * show), data = conc$spaghetti.labels,
                      color = 'black', size = spaghetti.label.size) + 
      geom_line(aes(x = x, y = y, group = unit, alpha = spaghetti.line.alpha * show), data = conc$spaghetti.lines,
                color = 'black', size = spaghetti.line.width) 
  }

  # facet if we want multiple facets
  if (!all(conc$lines$facet == conc$lines$facet[1])) {
    if (facet.vertical) { p = p + facet_grid(facet ~ ., scales = 'free_y') }
    else { p = p + facet_grid(. ~ facet) }
  }
  # if only one estimate per facet, exclude estimate-denoting linetype from legend
  if (is.null(facet)) { p = p + guides(linetype = 'none') }

  # If timesteps(Y) is a date, the x coordinates in our plot are
  # as.numeric(timesteps(Y)), i.e. days since the unix epoch (1970-01-01).
  # To improve readability, display the ticks as Dates. If the column names do
  # not parse as dates, as.Date errors and the plot is left unchanged.
  p = tryCatch({
    as.Date(colnames(attr(estimates[[1]], 'setup')$Y))
    p + scale_x_continuous(labels = function(time) { as.Date(time, origin = '1970-01-01') })
  }, error = function(e) { p })

  p + xlab('') + ylab('') + labs(color = '', fill = '') + scale_alpha(guide = 'none') +
    theme_bw() + theme(legend.direction = "horizontal", legend.position = "none",
                          axis.title = element_blank(),
                          text = element_text(size = 18))
}